{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use ONNX to load a converted model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### onnx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graph torch-jit-export (\n",
      "  %1[FLOAT, 1x3x416x416]\n",
      ") initializers (\n",
      "  %2[FLOAT, 32x3x3x3]\n",
      "  %3[FLOAT, 32]\n",
      "  %4[FLOAT, 32]\n",
      "  %5[FLOAT, 32]\n",
      "  %6[FLOAT, 32]\n",
      "  %7[FLOAT, 64x32x3x3]\n",
      "  %8[FLOAT, 64]\n",
      "  %9[FLOAT, 64]\n",
      "  %10[FLOAT, 64]\n",
      "  %11[FLOAT, 64]\n",
      "  %12[FLOAT, 128x64x3x3]\n",
      "  %13[FLOAT, 128]\n",
      "  %14[FLOAT, 128]\n",
      "  %15[FLOAT, 128]\n",
      "  %16[FLOAT, 128]\n",
      "  %17[FLOAT, 64x128x1x1]\n",
      "  %18[FLOAT, 64]\n",
      "  %19[FLOAT, 64]\n",
      "  %20[FLOAT, 64]\n",
      "  %21[FLOAT, 64]\n",
      "  %22[FLOAT, 128x64x3x3]\n",
      "  %23[FLOAT, 128]\n",
      "  %24[FLOAT, 128]\n",
      "  %25[FLOAT, 128]\n",
      "  %26[FLOAT, 128]\n",
      "  %27[FLOAT, 256x128x3x3]\n",
      "  %28[FLOAT, 256]\n",
      "  %29[FLOAT, 256]\n",
      "  %30[FLOAT, 256]\n",
      "  %31[FLOAT, 256]\n",
      "  %32[FLOAT, 128x256x1x1]\n",
      "  %33[FLOAT, 128]\n",
      "  %34[FLOAT, 128]\n",
      "  %35[FLOAT, 128]\n",
      "  %36[FLOAT, 128]\n",
      "  %37[FLOAT, 256x128x3x3]\n",
      "  %38[FLOAT, 256]\n",
      "  %39[FLOAT, 256]\n",
      "  %40[FLOAT, 256]\n",
      "  %41[FLOAT, 256]\n",
      "  %42[FLOAT, 512x256x3x3]\n",
      "  %43[FLOAT, 512]\n",
      "  %44[FLOAT, 512]\n",
      "  %45[FLOAT, 512]\n",
      "  %46[FLOAT, 512]\n",
      "  %47[FLOAT, 256x512x1x1]\n",
      "  %48[FLOAT, 256]\n",
      "  %49[FLOAT, 256]\n",
      "  %50[FLOAT, 256]\n",
      "  %51[FLOAT, 256]\n",
      "  %52[FLOAT, 512x256x3x3]\n",
      "  %53[FLOAT, 512]\n",
      "  %54[FLOAT, 512]\n",
      "  %55[FLOAT, 512]\n",
      "  %56[FLOAT, 512]\n",
      "  %57[FLOAT, 256x512x1x1]\n",
      "  %58[FLOAT, 256]\n",
      "  %59[FLOAT, 256]\n",
      "  %60[FLOAT, 256]\n",
      "  %61[FLOAT, 256]\n",
      "  %62[FLOAT, 512x256x3x3]\n",
      "  %63[FLOAT, 512]\n",
      "  %64[FLOAT, 512]\n",
      "  %65[FLOAT, 512]\n",
      "  %66[FLOAT, 512]\n",
      "  %67[FLOAT, 1024x512x3x3]\n",
      "  %68[FLOAT, 1024]\n",
      "  %69[FLOAT, 1024]\n",
      "  %70[FLOAT, 1024]\n",
      "  %71[FLOAT, 1024]\n",
      "  %72[FLOAT, 512x1024x1x1]\n",
      "  %73[FLOAT, 512]\n",
      "  %74[FLOAT, 512]\n",
      "  %75[FLOAT, 512]\n",
      "  %76[FLOAT, 512]\n",
      "  %77[FLOAT, 1024x512x3x3]\n",
      "  %78[FLOAT, 1024]\n",
      "  %79[FLOAT, 1024]\n",
      "  %80[FLOAT, 1024]\n",
      "  %81[FLOAT, 1024]\n",
      "  %82[FLOAT, 512x1024x1x1]\n",
      "  %83[FLOAT, 512]\n",
      "  %84[FLOAT, 512]\n",
      "  %85[FLOAT, 512]\n",
      "  %86[FLOAT, 512]\n",
      "  %87[FLOAT, 1024x512x3x3]\n",
      "  %88[FLOAT, 1024]\n",
      "  %89[FLOAT, 1024]\n",
      "  %90[FLOAT, 1024]\n",
      "  %91[FLOAT, 1024]\n",
      "  %92[FLOAT, 1024x1024x3x3]\n",
      "  %93[FLOAT, 1024]\n",
      "  %94[FLOAT, 1024]\n",
      "  %95[FLOAT, 1024]\n",
      "  %96[FLOAT, 1024]\n",
      "  %97[FLOAT, 1024x1024x3x3]\n",
      "  %98[FLOAT, 1024]\n",
      "  %99[FLOAT, 1024]\n",
      "  %100[FLOAT, 1024]\n",
      "  %101[FLOAT, 1024]\n",
      "  %102[FLOAT, 64x512x1x1]\n",
      "  %103[FLOAT, 64]\n",
      "  %104[FLOAT, 64]\n",
      "  %105[FLOAT, 64]\n",
      "  %106[FLOAT, 64]\n",
      "  %107[FLOAT, 1024x1280x3x3]\n",
      "  %108[FLOAT, 1024]\n",
      "  %109[FLOAT, 1024]\n",
      "  %110[FLOAT, 1024]\n",
      "  %111[FLOAT, 1024]\n",
      "  %112[FLOAT, 425x1024x1x1]\n",
      "  %113[FLOAT, 425]\n",
      ") {\n",
      "  %115 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%1, %2)\n",
      "  %117 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%115, %3, %4, %5, %6)\n",
      "  %118 = LeakyRelu[alpha = 0.100000001490116](%117)\n",
      "  %119 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%118)\n",
      "  %121 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%119, %7)\n",
      "  %123 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%121, %8, %9, %10, %11)\n",
      "  %124 = LeakyRelu[alpha = 0.100000001490116](%123)\n",
      "  %125 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%124)\n",
      "  %127 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%125, %12)\n",
      "  %129 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%127, %13, %14, %15, %16)\n",
      "  %130 = LeakyRelu[alpha = 0.100000001490116](%129)\n",
      "  %132 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%130, %17)\n",
      "  %134 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%132, %18, %19, %20, %21)\n",
      "  %135 = LeakyRelu[alpha = 0.100000001490116](%134)\n",
      "  %137 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%135, %22)\n",
      "  %139 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%137, %23, %24, %25, %26)\n",
      "  %140 = LeakyRelu[alpha = 0.100000001490116](%139)\n",
      "  %141 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%140)\n",
      "  %143 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%141, %27)\n",
      "  %145 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%143, %28, %29, %30, %31)\n",
      "  %146 = LeakyRelu[alpha = 0.100000001490116](%145)\n",
      "  %148 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%146, %32)\n",
      "  %150 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%148, %33, %34, %35, %36)\n",
      "  %151 = LeakyRelu[alpha = 0.100000001490116](%150)\n",
      "  %153 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%151, %37)\n",
      "  %155 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%153, %38, %39, %40, %41)\n",
      "  %156 = LeakyRelu[alpha = 0.100000001490116](%155)\n",
      "  %157 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%156)\n",
      "  %159 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%157, %42)\n",
      "  %161 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%159, %43, %44, %45, %46)\n",
      "  %162 = LeakyRelu[alpha = 0.100000001490116](%161)\n",
      "  %164 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%162, %47)\n",
      "  %166 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%164, %48, %49, %50, %51)\n",
      "  %167 = LeakyRelu[alpha = 0.100000001490116](%166)\n",
      "  %169 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%167, %52)\n",
      "  %171 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%169, %53, %54, %55, %56)\n",
      "  %172 = LeakyRelu[alpha = 0.100000001490116](%171)\n",
      "  %174 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%172, %57)\n",
      "  %176 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%174, %58, %59, %60, %61)\n",
      "  %177 = LeakyRelu[alpha = 0.100000001490116](%176)\n",
      "  %179 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%177, %62)\n",
      "  %181 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%179, %63, %64, %65, %66)\n",
      "  %182 = LeakyRelu[alpha = 0.100000001490116](%181)\n",
      "  %183 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%182)\n",
      "  %185 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%183, %67)\n",
      "  %187 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%185, %68, %69, %70, %71)\n",
      "  %188 = LeakyRelu[alpha = 0.100000001490116](%187)\n",
      "  %190 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%188, %72)\n",
      "  %192 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%190, %73, %74, %75, %76)\n",
      "  %193 = LeakyRelu[alpha = 0.100000001490116](%192)\n",
      "  %195 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%193, %77)\n",
      "  %197 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%195, %78, %79, %80, %81)\n",
      "  %198 = LeakyRelu[alpha = 0.100000001490116](%197)\n",
      "  %200 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%198, %82)\n",
      "  %202 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%200, %83, %84, %85, %86)\n",
      "  %203 = LeakyRelu[alpha = 0.100000001490116](%202)\n",
      "  %205 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%203, %87)\n",
      "  %207 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%205, %88, %89, %90, %91)\n",
      "  %208 = LeakyRelu[alpha = 0.100000001490116](%207)\n",
      "  %210 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%208, %92)\n",
      "  %212 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%210, %93, %94, %95, %96)\n",
      "  %213 = LeakyRelu[alpha = 0.100000001490116](%212)\n",
      "  %215 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%213, %97)\n",
      "  %217 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%215, %98, %99, %100, %101)\n",
      "  %218 = LeakyRelu[alpha = 0.100000001490116](%217)\n",
      "  %220 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%182, %102)\n",
      "  %222 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%220, %103, %104, %105, %106)\n",
      "  %223 = LeakyRelu[alpha = 0.100000001490116](%222)\n",
      "  %224 = Reshape[shape = [1, 64, 13, 2, 13, 2]](%223)\n",
      "  %225 = Transpose[perm = [0, 1, 2, 4, 3, 5]](%224)\n",
      "  %226 = Reshape[shape = [1, 64, 169, 4]](%225)\n",
      "  %227 = Transpose[perm = [0, 1, 3, 2]](%226)\n",
      "  %228 = Reshape[shape = [1, 64, 4, 13, 13]](%227)\n",
      "  %229 = Transpose[perm = [0, 2, 1, 3, 4]](%228)\n",
      "  %230 = Reshape[shape = [1, 256, 13, 13]](%229)\n",
      "  %231 = Concat[axis = 1](%230, %218)\n",
      "  %233 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%231, %107)\n",
      "  %235 = BatchNormalization[consumed_inputs = [0, 0, 0, 1, 1], epsilon = 9.99999974737875e-06, is_test = 1, momentum = 0.899999976158142](%233, %108, %109, %110, %111)\n",
      "  %236 = LeakyRelu[alpha = 0.100000001490116](%235)\n",
      "  %238 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%236, %112)\n",
      "  %239 = Add[axis = 1, broadcast = 1](%238, %113)\n",
      "  return %239\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import onnx\n",
    "\n",
    "# Load the ONNX model\n",
    "model = onnx.load(\"onnx/yolo2.onnx\")\n",
    "\n",
    "# Check that the IR is well formed\n",
    "onnx.checker.check_model(model)\n",
    "\n",
    "# Print a human readable representation of the graph\n",
    "print(onnx.helper.printable_graph(model.graph))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explore the ONNX IR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- ref: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto\n",
    "- ref: https://github.com/onnx/onnx/blob/master/onnx/examples/Protobufs.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(onnx_pb2.ModelProto,\n",
       " onnx_pb2.GraphProto,\n",
       " google.protobuf.pyext._message.RepeatedCompositeContainer,\n",
       " 81)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(model), type(model.graph), type(model.graph.node), len(model.graph.node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(onnx_pb2.NodeProto,\n",
       " google.protobuf.pyext._message.RepeatedCompositeContainer,\n",
       " onnx_pb2.AttributeProto,\n",
       " 5)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node = model.graph.node[0]\n",
    "type(node), type(node.attribute), type(node.attribute[0]), len(node.attribute)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### explore node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input: \"1\"\n",
      "input: \"2\"\n",
      "output: \"115\"\n",
      "op_type: \"Conv\"\n",
      "attribute {\n",
      "  name: \"kernel_shape\"\n",
      "  ints: 3\n",
      "  ints: 3\n",
      "  type: INTS\n",
      "}\n",
      "attribute {\n",
      "  name: \"strides\"\n",
      "  ints: 1\n",
      "  ints: 1\n",
      "  type: INTS\n",
      "}\n",
      "attribute {\n",
      "  name: \"pads\"\n",
      "  ints: 1\n",
      "  ints: 1\n",
      "  ints: 1\n",
      "  ints: 1\n",
      "  type: INTS\n",
      "}\n",
      "attribute {\n",
      "  name: \"dilations\"\n",
      "  ints: 1\n",
      "  ints: 1\n",
      "  type: INTS\n",
      "}\n",
      "attribute {\n",
      "  name: \"group\"\n",
      "  i: 1\n",
      "  type: INT\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(node) # Tree-like IR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "u'%115 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%1, %2)'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "onnx.helper.printable_node(node) # Flat IR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### explore attribute in node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "attr = node.attribute[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name: \"kernel_shape\"\n",
      "ints: 3\n",
      "ints: 3\n",
      "type: INTS\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(attr) # Tree-like IR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "u'kernel_shape = [3, 3]'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "onnx.helper.printable_attribute(attr)  # Flat IR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### prepare image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image \n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgfile = './data/dog.jpg'\n",
    "img = Image.open(imgfile).convert('RGB').resize( (416, 416) )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_arr = np.array(img)\n",
    "img_arr = np.expand_dims(img_arr, -1)\n",
    "img_arr = np.transpose(img_arr, (3,2,0,1))/255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 3, 416, 416)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img_arr.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the model with the Caffe2/TensorFlow backend"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "https://ptorch.com/news/95.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python2.7/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "/usr/local/lib/python2.7/dist-packages/onnx_tf/backend.py:677: UserWarning: Unsupported kernel_shape attribute by Tensorflow in Conv operator. The attribute will be ignored.\n",
      "  UserWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 38.3 s, sys: 626 ms, total: 39 s\n",
      "Wall time: 38.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# import onnx_caffe2.backend as backend_1\n",
    "import onnx_tf.backend as backend\n",
    "import numpy as np\n",
    "\n",
    "rep = backend.prepare(model, device=\"CUDA:0\") # or \"CPU\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "((1, 425, 13, 13), <class 'onnx.backend.base.Outputs'>)\n",
      "CPU times: user 8.19 s, sys: 992 ms, total: 9.18 s\n",
      "Wall time: 2.99 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# For the Caffe2 backend:\n",
    "#     rep.predict_net is the Caffe2 protobuf for the network\n",
    "#     rep.workspace is the Caffe2 workspace for the network\n",
    "#       (see the class onnx_caffe2.backend.Workspace)\n",
    "outputs = rep.run(img_arr.astype(np.float32))\n",
    "# To run networks with more than one input, pass a tuple\n",
    "# rather than a single numpy ndarray.\n",
    "print(outputs[0].shape, type(outputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 425, 13, 13)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs = np.array(outputs).squeeze(0)\n",
    "outputs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load detection information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['num_anchors', 'anchors', 'num_classes']"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pickle\n",
    "detection_information = pickle.load(open('detection_information.pkl','rb'))\n",
    "detection_information.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_anchors, anchors, num_classes = [detection_information[k] for k in detection_information.keys()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use the original pytorch-yolo2 module to detect objects in the outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.cuda.FloatTensor"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "output = torch.FloatTensor(outputs).cuda() \n",
    "type(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def detect(output, conf_thresh=0.5, nms_thresh=0.4):\n",
    "\n",
    "    if num_classes == 20:\n",
    "        namesfile = 'data/voc.names'\n",
    "    elif num_classes == 80:\n",
    "        namesfile = 'data/coco.names'\n",
    "    else:\n",
    "        namesfile = 'data/names'\n",
    "    \n",
    "    for i in range(2):\n",
    "        boxes = get_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors)[0]\n",
    "        boxes = nms(boxes, nms_thresh)\n",
    "\n",
    "    class_names = load_class_names(namesfile)\n",
    "    plot_boxes(img, boxes, 'predictions.jpg', class_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "truck: 0.934710\n",
      "bicycle: 0.998012\n",
      "dog: 0.990524\n",
      "save plot results to predictions.jpg\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "utils.py:140: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  cls_confs = torch.nn.Softmax()(Variable(output[5:5+num_classes].transpose(0,1))).data\n"
     ]
    }
   ],
   "source": [
    "detect(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
